# (这份代码跑起来会报错，仅提供思路参考)
import os
from dotenv import load_dotenv, find_dotenv # 导入 find_dotenv 帮助定位
from langchain.prompts import PromptTemplate
from langchain.chains import ConversationChain
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.chains.router.llm_router import LLMRouterChain,RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.chains.router import MultiPromptChain

# 加载 .env 文件中的环境变量 (增强调试)
load_dotenv(dotenv_path=find_dotenv(usecwd=True), verbose=True, override=True)

# 从环境变量加载 API 密钥和基础 URL
api_key = os.getenv("OPENAI_API_KEY")
api_base = os.getenv("OPENAI_API_BASE")
os.environ["OPENAI_API_KEY"] = api_key
os.environ["OPENAI_API_BASE"] = api_base

# 使用MultiPromptChain来决定具体使用哪条链路，有默认链路
#物理链
physics_template = """您是一位非常聪明的物理教授.\n
您擅长以简洁易懂的方式回答物理问题.\n
当您不知道问题答案的时候，您会坦率承认不知道.\n
下面是一个问题:
{input}"""
physics_prompt = PromptTemplate.from_template(physics_template)

#数学链
math_template = """您是一位非常优秀的数学教授.\n
您擅长回答数学问题.\n
您之所以如此优秀，是因为您能够将困难问题分解成组成的部分，回答这些部分，然后将它们组合起来，回答更广泛的问题.\n
下面是一个问题:
{input}"""
math_prompt = PromptTemplate.from_template(math_template)

prompt_infos = [
    {
        "name":"physics",
        "description":"擅长回答物理问题",
        "prompt_template":physics_template,
    },
    {
        "name":"math",
        "description":"擅长回答数学问题",
        "prompt_template":math_template,
    },
]

llm = OpenAI(
    temperature = 0,
    model="gpt-3.5-turbo-instruct"
)
destination_chains = {}
for p_info in prompt_infos:
    name = p_info["name"]
    prompt_template = p_info["prompt_template"]
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["input"]
    )
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
    )
    destination_chains[name] = chain
default_chain = ConversationChain(
    llm = llm,
    output_key="text"
)

destinations = [f"{p['name']}:{p['description']}" for p in prompt_infos]
destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(destinations=destinations_str)
#print(MULTI_PROMPT_ROUTER_TEMPLATE)

router_prompt = PromptTemplate(
    template=router_template,
    input_variables=["input"],
    output_parser=RouterOutputParser()
)
router_chain = LLMRouterChain.from_llm(
    llm,
    router_prompt
)

chain = MultiPromptChain(
    router_chain=router_chain,
    destination_chains=destination_chains,
    default_chain=default_chain,
    verbose=True
)
chain.run("什么是牛顿第一定律?")